### © Aaron Gilkison and Maciej Kurzynski 
### Vectors of Violence: Legitimation and Distribution of State Power in the ''People’s Liberation Army Daily'' (Jiefangjun Bao)
### The Journal of Cultural Analytics

import os
import re
import json
import numpy as np
import random
import pickle
import torch
import math
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from scipy.stats import fisher_exact
from collections import Counter, defaultdict
from tqdm import tqdm, trange
from transformers import BertTokenizer, BertModel, BertForSequenceClassification, AdamW
from sklearn.cluster import KMeans
from torch.utils.data import DataLoader
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.feature_extraction.text import CountVectorizer
from matplotlib.colors import ListedColormap, LinearSegmentedColormap

# Chinese-capable font used for plot labels/legends containing CJK text.
fprop = plt.font_manager.FontProperties(fname = 'STHeiti Medium.ttc', size=12)

# Inclusive range of publication years analysed throughout the script.
start_year = 1956
end_year = 1989

# Corpus: a list of article dicts. Fields read later include "text", "text_split",
# "sentences", "issue_date" and "article_id" (schema inferred from usage — confirm
# against whatever produced the JSON).
with open("data/JFJB_1956-1989_split_sent.json", "r", encoding="utf-8") as infile:
    JFJB = json.load(infile)

# The stopword list is Chinese: decode it explicitly as UTF-8 instead of relying on
# the platform default encoding, and close the file once read.
with open("stopwords_zh.txt", "r", encoding="utf-8") as stopword_file:
    stopwords = [word.strip() for word in stopword_file.readlines()]
stopwords.extend(["照片","图片","其他","其它"])

# ------------------------
# ------------------------    
# TABLE 1
# ------------------------
# ------------------------

def make_vocab(texts, min_freq=10):
    """Build a frequency-ranked vocabulary from whitespace-tokenised texts.

    Args:
        texts: iterable of strings whose tokens are separated by whitespace.
        min_freq: minimum number of occurrences for a token to be kept.

    Returns:
        List of tokens occurring at least ``min_freq`` times, most frequent
        first; tokens with equal counts keep their first-seen order.
    """
    token_counts = Counter()
    for document in tqdm(texts):
        token_counts.update(document.split())
    # most_common() is a stable descending sort, so ties stay in first-seen order.
    return [token for token, count in token_counts.most_common() if count >= min_freq]

def co_occurrence_matrix(texts, vocab, window_size=10):
    """Count windowed co-occurrences between vocabulary words.

    ``matrix[a, b]`` is the number of ordered token-position pairs ``(i, j)``
    within one text such that ``i != j``, ``|i - j| <= window_size``, token i
    is ``vocab[a]`` and token j is ``vocab[b]``. Out-of-vocabulary tokens are
    ignored (but still occupy positions), and windows never cross text
    boundaries. Because the window is symmetric the matrix is symmetric, and
    each row sum equals the number of in-vocabulary neighbours of all
    occurrences of that word.

    Args:
        texts: iterable of whitespace-tokenised strings.
        vocab: sequence of unique words; word ``vocab[k]`` maps to row/column k.
        window_size: maximum token distance that still counts as co-occurring.

    Returns:
        Dense ``(len(vocab), len(vocab))`` numpy array of dtype int32.

    Raises:
        ValueError: if ``vocab`` contains duplicate words.
    """
    word2id = {word: idx for idx, word in enumerate(vocab)}
    # Compare against the raw vocab length: the previous check compared two
    # de-duplicated sizes and therefore could never fail.
    if len(word2id) != len(vocab):
        raise ValueError("vocab contains duplicate words")

    cooc_counts = defaultdict(int)

    for text in tqdm(texts):
        # Map each token to its id once; None marks out-of-vocabulary tokens.
        ids = [word2id.get(word) for word in text.split()]
        n_tokens = len(ids)

        for i, center_id in enumerate(ids):
            if center_id is None:
                continue

            left_boundary = max(0, i - window_size)
            right_boundary = min(n_tokens, i + window_size + 1)

            for j in range(left_boundary, right_boundary):
                context_id = ids[j]
                if j != i and context_id is not None:
                    # Record the ordered pair only. The mirrored cell is filled
                    # when the loop reaches position j. Sorting the pair AND
                    # mirroring it afterwards (the previous behaviour) counted
                    # off-diagonal cells twice relative to the diagonal.
                    cooc_counts[(center_id, context_id)] += 1

    vocab_size = len(vocab)
    matrix = np.zeros((vocab_size, vocab_size), dtype=np.int32)
    for (row, col), count in cooc_counts.items():
        matrix[row, col] = count

    return matrix

# TABLE 1 OPTION 1: data for MDWs for documents
# Each article's "text_split" is a single whitespace-tokenised string (it is
# .split() elsewhere in this script). It must therefore be appended as one
# document: list.extend() on a string would add every character as its own
# "document".
texts = []
for article in tqdm(JFJB):
    texts.append(article["text_split"])

# TABLE 1 OPTION 2: data for MDWs for sentences
# NOTE(review): article["sentences"] is assumed to be a list of whitespace-
# tokenised sentence strings (hence extend) — confirm against the JSON producer.
texts = []
for article in tqdm(JFJB):
    texts.extend(article["sentences"])

# Sentence-term counts for words present in at least 100 sentences. X^T X yields
# a (vocab x vocab) sparse co-occurrence matrix at sentence level; its diagonal
# holds the sum of squared per-sentence counts.
vectorizer = CountVectorizer(tokenizer=lambda x: x.split(), min_df=100)
X = vectorizer.fit_transform(texts)
cooc_matrix = X.T.dot(X)

# get_feature_names() was removed in scikit-learn 1.2. Use its replacement and
# fall back only on old releases (< 1.0) that predate it.
if hasattr(vectorizer, "get_feature_names_out"):
    feature_names = vectorizer.get_feature_names_out()
else:
    feature_names = vectorizer.get_feature_names()

w2id = {word: index for (index, word) in enumerate(feature_names)}
id2w = {index: word for (index, word) in enumerate(feature_names)}

# TABLE 1 OPTION 3: data for MDWs for windows:
# "text_split" is one whitespace-tokenised string per article, so it is appended
# whole. extend() would split it into single characters, producing a
# character-level vocabulary and a matrix with no co-occurrences at all.
texts = []
for article in tqdm(JFJB):
    texts.append(article["text_split"])
JFJB_vocab = make_vocab(texts, 10)
cooc_matrix = co_occurrence_matrix(texts, JFJB_vocab, window_size=5)
w2id = {word: idx for idx, word in enumerate(JFJB_vocab)}
id2w = {idx: word for idx, word in enumerate(JFJB_vocab)}

# calculate MDWs (most distinctive words) for each target term
total_count = cooc_matrix.sum()
# Marginal context probability of each word: column sums over the grand total.
# np.asarray(...).flatten() also handles the sparse matrix built by OPTION 2.
context_probs = cooc_matrix.sum(axis=0) / total_count
context_probs = np.asarray(context_probs).flatten()

targets = ["战士","军人","民兵"]
MDWs = {}
for target in targets:
    MDWs[target] = []
    target_id = w2id[target]
    target_prob = context_probs[target_id]
    # The target's row holds all of its collocates, including itself (the matrix
    # is symmetric, so this equals the column sum).
    collocates_count = int(cooc_matrix[target_id].sum())
    for collocate in tqdm(w2id):
        collocate_id = w2id[collocate]
        collocate_prob = context_probs[collocate_id]
        n_obs = int(cooc_matrix[collocate_id, target_id])
        coocc_prob = n_obs / total_count

        # PMI^2 = log(p(x, y)^2 / (p(x) * p(y))); defined as 0 when either side is 0.
        denominator = collocate_prob * target_prob
        numerator = coocc_prob ** 2
        if denominator != 0 and numerator != 0:
            pmi = math.log(numerator / denominator)
        else:
            pmi = 0

        count_x = n_obs
        expected_x = round(collocate_prob * collocates_count)
        if expected_x == 0:
            # The observed/expected ratio is undefined here. Previously such
            # collocates were dropped via a ZeroDivisionError swallowed by a bare
            # except; skip them explicitly instead.
            continue
        count_non_x = collocates_count - count_x
        expected_non_x = round(collocates_count - expected_x)

        try:
            _, p_value = fisher_exact([[count_x, expected_x], [count_non_x, expected_non_x]], alternative='greater')
        except ValueError as e:
            # fisher_exact rejects malformed tables (e.g. negative entries).
            print(f"Problem with calculating p value for {collocate}: {e}. Skipping...")
            continue
        if p_value < 0.01:
            MDWs[target].append((collocate, n_obs, expected_x, round(n_obs / expected_x, 3), p_value, pmi))

    # Highest observed/expected ratio first.
    MDWs[target] = sorted(MDWs[target], key=lambda x: -x[3])


# ------------------------
# ------------------------    
# FIGURE 1
# ------------------------
# ------------------------

years = list(range(start_year, end_year+1))
year_id_dict = {}
year_totals = defaultdict(int)

# Index articles by publication year and count the total number of words per year.
for index, article in enumerate(tqdm(JFJB)):
    curr_year = int(article["issue_date"].split("-")[0])
    if start_year <= curr_year <= end_year:
        year_id_dict.setdefault(curr_year, []).append(index)
        year_totals[curr_year] += len(article["text_split"].split())

terms = ["战士", "军人", "民兵"]
terms_freqs = {term: [0] * len(years) for term in terms}

# Raw yearly counts of each term (exact token matches).
for year_index, year in enumerate(tqdm(years)):
    for article_id in year_id_dict.get(year, []):
        # One Counter pass per article instead of one list.count() scan per term.
        token_counts = Counter(JFJB[article_id]["text_split"].split())
        for term in terms:
            terms_freqs[term][year_index] += token_counts[term]

terms_freqs = {term: np.array(freqs) for term, freqs in terms_freqs.items()}

colors = ["steelblue", "orange", "red"]
fig, ax = plt.subplots(figsize=(12, 6))
labels = ["战士 $zhanshi$", "军人 $junren$", "民兵 $minbing$"]

# Relative frequency = term count / all words published that year. The denominator
# is built explicitly in `years` order. Reading year_totals.values() followed
# first-seen corpus order, which mis-pairs years unless JFJB is chronologically
# sorted and fails to broadcast when a year has no articles. Years with no words
# get 0 instead of a division error or NaN.
year_word_totals = np.array([year_totals.get(year, 0) for year in years], dtype=float)
term_counts = np.array([terms_freqs[term] for term in terms], dtype=float)
relative_freqs = np.divide(term_counts, year_word_totals, out=np.zeros_like(term_counts), where=year_word_totals > 0)
plt.stackplot(years, relative_freqs, labels=labels, colors=colors)

plt.legend(loc="lower right", prop=fprop)
plt.xlabel("Year", fontsize=14)
plt.ylabel("Term Frequency", fontsize=14)
plt.grid(True, which="both", linestyle="--", linewidth=0.5)
plt.xlim((start_year, end_year))
displayable_xticks = [1956, 1960, 1965, 1970, 1975, 1980, 1985, 1989]
ax.set_xticks(displayable_xticks)
ax.set_xticklabels(displayable_xticks)

plt.tight_layout()
plt.savefig("term_frequencies.png", format="png", dpi=300, bbox_inches='tight')


# ------------------------
# ------------------------
# FIGURE 2
# ------------------------
# ------------------------

# Split every sufficiently long article into BERT-sized character chunks.
inputs = []
min_length = 100  # minimum article length, in whitespace-separated words
max_length = 510  # 512 BERT positions minus [CLS]/[SEP]; chunks are cut by character

for JFJB_index, article in enumerate(tqdm(JFJB)):
    if len(article["text_split"].split()) <= min_length:
        continue  # article too short to be chunked
    # Remove all whitespace, strip the literal marker "图片／照片／其它", and keep only
    # the part before the first "＠＠" separator.
    # NOTE(review): the pattern matches that exact string, not each word on its own
    # — confirm this is the intended marker.
    cleaned = re.sub(r'\s+', '', article["text"])
    cleaned = re.sub("图片／照片／其它", "", cleaned)
    cleaned = cleaned.split("＠＠")[0]
    for start in range(0, len(cleaned), max_length):
        chunk = cleaned[start:start + max_length]
        # Chunks of 200 characters or fewer (typically the tail) are discarded.
        if len(chunk) > min_length * 2:
            inputs.append({"JFJB_id": JFJB_index, "text": chunk})

# Violence lexicon: the unique, non-empty, stripped lines of violent_terms.txt.
# The file is read as UTF-8 (the terms are Chinese) and closed afterwards.
# De-duplicating after stripping ensures a term with and without a trailing newline
# is not counted twice. Sorting makes the order reproducible across runs.
with open("violent_terms.txt", "r", encoding="utf-8") as violent_file:
    violent_terms = sorted({line.strip() for line in violent_file if line.strip()})

def is_violent(passage, violent_vocab, min_unique_terms=5):
    """Heuristically decide whether *passage* is violent.

    A passage qualifies when two conditions both hold:

    * occurrences of vocabulary terms cover more than a tenth of its
      characters (each term contributes ``count * len(term)``, so overlapping
      terms are counted separately);
    * at least *min_unique_terms* distinct terms occur in it.
    """
    covered_chars = 0
    distinct_hits = 0
    for term in violent_vocab:
        occurrences = passage.count(term)
        covered_chars += occurrences * len(term)
        if occurrences > 0:
            distinct_hits += 1
    return covered_chars > (len(passage) / 10) and distinct_hits >= min_unique_terms

def is_nonviolent(passage, violent_vocab):
    """Return True when no non-empty vocabulary term occurs in *passage*."""
    # Empty terms cover zero characters, so they never make a passage violent.
    return not any(term and term in passage for term in violent_vocab)

# Keep every BERT chunk that passes the lexicon-based violence heuristic.
violent_passages = []
for input_index, chunk_info in enumerate(tqdm(inputs)):
    if is_violent(chunk_info["text"], violent_terms):
        violent_passages.append({
            "input_index": input_index,
            "JFJB_index": chunk_info["JFJB_id"],
            "text": chunk_info["text"],
        })

# Per-year tallies of violent chunks and their positions in `inputs`.
# A publication year outside [start_year, end_year] raises KeyError.
all_years = range(start_year, end_year + 1)
violent_text_counts = dict.fromkeys(all_years, 0)
violent_text_id_by_year = {year: [] for year in all_years}
for passage in violent_passages:
    year = int(JFJB[passage["JFJB_index"]]["issue_date"].split("-")[0])
    violent_text_counts[year] += 1
    violent_text_id_by_year[year].append(passage["input_index"])

# Balanced fine-tuning set. Violent side: at most 100 violent chunks per year.
# NOTE(review): no random.seed call is visible in this script, so the sample
# differs between runs.
violent_passages_sampled = []
for year, violent_ids in violent_text_id_by_year.items():
    k = min(100, len(violent_ids))  # take max 100 violent chunks or fewer per year
    for input_id in random.sample(violent_ids, k=k):
        violent_passages_sampled.append(inputs[input_id]["text"])

# Non-violent side: chunks that contain no violent term at all. Only a random
# pre-sample (20x the number of violent chunks) is screened, to reduce computation.
# k is capped so random.sample cannot fail on a corpus smaller than that.
candidate_pool = random.sample(inputs, k=min(len(inputs), len(violent_passages) * 20))
non_violent_candidates = [candidate["text"] for candidate in tqdm(candidate_pool) if is_nonviolent(candidate["text"], violent_terms)]
if len(non_violent_candidates) < len(violent_passages_sampled):
    raise ValueError(
        f"Only {len(non_violent_candidates)} non-violent chunks found, but "
        f"{len(violent_passages_sampled)} are needed to balance the violent sample."
    )
non_violent_passages = random.sample(non_violent_candidates, k=len(violent_passages_sampled))

finetuning_passages = {"violent": violent_passages_sampled, "non_violent": non_violent_passages}

def get_embeddings(articles, model, batch_size=64, max_length=512, file=None, save_every_batches=200):
    """Return the [CLS] embedding produced by ``model.bert`` for each text.

    Uses the module-level ``tokenizer`` and ``device``.

    Args:
        articles: list of strings to embed.
        model: a model exposing a ``.bert`` encoder (here a
            BertForSequenceClassification). It is switched to eval mode and
            left there.
        batch_size: number of texts per forward pass.
        max_length: tokenizer truncation/padding length. This was previously
            ignored in favour of a hard-coded 512; the default keeps that value.
        file: optional pickle checkpoint path. If the file exists, its
            embeddings are loaded and encoding resumes after them. Progress is
            saved every ``save_every_batches`` batches and once more at the end.
            Without a file, nothing is saved.
        save_every_batches: checkpoint frequency, in batches.

    Returns:
        List of 1-D numpy arrays, one per article, in input order.
    """
    model.eval()
    embeddings = []
    if file and os.path.exists(file):
        # A makeshift way to restart from a checkpoint; Colab blocks unattended
        # execution pretty often.
        # NOTE: pickle.load can execute arbitrary code, so only load checkpoints you
        # wrote yourself. The checkpoint is not tied to a model version: delete it
        # after re-training, otherwise stale embeddings are returned.
        with open(file, "rb") as f:
            embeddings = pickle.load(f)
        print(f"Embeddings loaded from file: {file}. Current length: {len(embeddings)}.")

    def _save_checkpoint():
        # Persist everything computed so far to `file`.
        with open(file, "wb") as f:
            pickle.dump(embeddings, f)
        print(f"Saved to file: {file}. Currently there are {len(embeddings)} embeddings.")

    start = len(embeddings)
    for batch_index, i in enumerate(trange(start, len(articles), batch_size)):
        batch = articles[i:i+batch_size]
        encoded = tokenizer(batch, return_tensors="pt", truncation=True, padding='max_length', max_length=max_length).to(device)

        with torch.no_grad():
            outputs = model.bert(input_ids=encoded["input_ids"], attention_mask=encoded["attention_mask"])

        # Position 0 of the last hidden state is the [CLS] token.
        cls_embeddings = outputs.last_hidden_state[:, 0, :].cpu().numpy()

        embeddings.extend(cls_embeddings)
        # Checkpoint only when a file was given; open(None) would raise TypeError.
        if file and batch_index % save_every_batches == 0 and batch_index != 0:
            _save_checkpoint()

    # Save the final state so a re-run can skip the computation entirely.
    if file and len(embeddings) > start:
        _save_checkpoint()

    return embeddings

def get_highest_similarity(embedding, prototypes):
    """Return the largest cosine similarity between *embedding* and any prototype row."""
    # cosine_similarity works on 2-D inputs, so the single vector is wrapped in a list;
    # the result is a 1 x len(prototypes) matrix whose maximum is returned.
    return cosine_similarity([embedding], prototypes).max()

# FINETUNING BERT-BASE-CHINESE

# Run on the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Pretrained Chinese BERT with a newly initialised two-way classification head.
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
model = BertForSequenceClassification.from_pretrained('bert-base-chinese', num_labels=2).to(device)

# Training examples: violent passages are labelled 1, non-violent passages 0.
violent_texts = finetuning_passages["violent"]
non_violent_texts = finetuning_passages["non_violent"]
texts = violent_texts + non_violent_texts
labels = [1] * len(violent_texts) + [0] * len(non_violent_texts)

assert len(texts) == len(labels)

# Shuffle the (text, label) pairs jointly, then unpack them back in place so each
# text keeps its label.
paired_examples = list(zip(texts, labels))
random.shuffle(paired_examples)
texts[:], labels[:] = zip(*paired_examples)

# Tokenise every example at once into fixed-length (512) id and mask tensors.
encoding = tokenizer(texts, return_tensors='pt', padding='max_length', truncation=True, max_length=512)
inputs_bert, attention_masks = encoding['input_ids'], encoding['attention_mask']
labels = torch.tensor(labels)

# Dataloader: batches of 16 (input_ids, attention_mask, label) triples, reshuffled
# at the start of every epoch.
dataset = torch.utils.data.TensorDataset(inputs_bert, attention_masks, labels)
dataloader = DataLoader(dataset, shuffle=True, batch_size=16)

# Optimizer: AdamW over all model parameters with a small fine-tuning learning rate.
# NOTE(review): transformers.AdamW is deprecated in recent transformers releases.
# torch.optim.AdamW with weight_decay=0.0 is the usual replacement — confirm the
# defaults before switching.
learning_rate = 2e-5
optimizer = AdamW(model.parameters(), lr=learning_rate)

# Training loop: a single epoch over the shuffled fine-tuning set (more epochs or
# early stopping could be added if needed).
num_epochs = 1
for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    batch_iterator = tqdm(dataloader, desc=f"Training, Epoch {epoch+1}")
    for batch_index, batch in enumerate(batch_iterator):
        batch_inputs, batch_masks, batch_labels = batch
        model.zero_grad()
        # Labels are passed, so the model also computes the classification loss;
        # it is read from the first element of the outputs.
        outputs = model(batch_inputs.to(device), attention_mask=batch_masks.to(device), labels=batch_labels.to(device))
        loss = outputs[0]
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        batch_iterator.set_postfix(loss=loss.item())

        # Diagnostic every 20 batches (skipping batch 0):
        #   1. re-embed all fine-tuning passages;
        #   2. cluster the violent embeddings into 3 prototypes;
        #   3. print the mean highest prototype similarity, first for violent and
        #      then for non-violent passages.
        # The hope is violent ≈ 1 and non-violent ≈ 0.
        # NOTE(review): this re-embeds the whole fine-tuning set each time (expensive),
        # and KMeans is unseeded, so the printed values vary between runs.
        if batch_index % 20 == 0 and batch_index != 0: 
            model.eval()
            embeddings_violent = get_embeddings(finetuning_passages["violent"], model)
            embeddings_nonviolent = get_embeddings(finetuning_passages["non_violent"], model)

            kmeans = KMeans(n_clusters=3) # number of prototypes; other values were tried
            kmeans.fit(embeddings_violent)
            prototypes = kmeans.cluster_centers_

            highest_similarities_violent = [get_highest_similarity(embed, prototypes) for embed in tqdm(embeddings_violent)]
            highest_similarities_nonviolent = [get_highest_similarity(embed, prototypes) for embed in tqdm(embeddings_nonviolent)]
            print(sum(highest_similarities_violent)/len(highest_similarities_violent))
            print(sum(highest_similarities_nonviolent)/len(highest_similarities_nonviolent))
            # get_embeddings leaves the model in eval mode; switch back before training continues.
            model.train()
    # Mean training loss over all batches of the epoch.
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss/len(dataloader)}")

# elbow method: KMeans inertia for k = 1..5 on the fine-tuned violent embeddings
model.eval()
embeddings_violent = get_embeddings(finetuning_passages["violent"], model)
sum_of_squared_distances = []
K = range(1,6)
for k in tqdm(K):
    km = KMeans(n_clusters=k)
    km = km.fit(embeddings_violent)
    sum_of_squared_distances.append(km.inertia_)

# Start a fresh figure. Without it, the curve is drawn onto whichever figure is
# still current (e.g. the Figure 1 stackplot, which is never closed).
plt.figure()
plt.plot(K, sum_of_squared_distances, 'bx-')
plt.xlabel('k')
plt.ylabel('Sum of squared distances')
plt.title('Elbow Method For Optimal k')
plt.show()

# Final violence prototypes: 3 cluster centres of the violent embeddings
# (presumably the k read off the elbow plot).
kmeans = KMeans(n_clusters=3)
kmeans.fit(embeddings_violent)
prototypes = kmeans.cluster_centers_

# get embeddings for all bert-ready inputs; the pickle file lets an interrupted run resume
# NOTE(review): an existing checkpoint file is loaded as-is, even if it was produced
# before the model was (re-)fine-tuned — delete it after retraining.
chunk_texts = [chunk_info["text"] for chunk_info in inputs]
embeddings_JFJB = get_embeddings(
    chunk_texts,
    model=model,
    batch_size=64,
    file="JFJB_embeddings_Feb29_2024_v1.pkl",
)

# Violence score of each chunk: its highest cosine similarity to any violent prototype.
highest_similarities = [get_highest_similarity(embedding, prototypes) for embedding in tqdm(embeddings_JFJB)]

years = range(start_year, end_year+1)
months = range(1, 13)

# Collect the chunk-level violence scores per (year, month) of the source article.
valences = {year: {month: [] for month in months} for year in years}
for chunk_index, chunk_info in enumerate(inputs):
    date = JFJB[chunk_info["JFJB_id"]]["issue_date"].split("-")
    year = int(date[0])
    month = int(date[1])
    valences[year][month].append(highest_similarities[chunk_index])

# Replace each list by its mean. A month without any chunk previously raised
# ZeroDivisionError; it now becomes None and shows up as a gap in the line plot.
for year in years:
    for month in months:
        month_scores = valences[year][month]
        valences[year][month] = sum(month_scores) / len(month_scores) if month_scores else None

# One padding month after end_year gives the final year boundary its own tick.
all_dates = ["{}-{:02}".format(year, month) for year in years for month in months] + ["{}-01".format(end_year + 1)]

all_valences = [valences[year][month] for year in years for month in months] + [None]  # dummy value for the padding month

plt.figure(figsize=(12, 6))
plt.plot(all_dates, all_valences, color='red', linewidth=2)
plt.xlabel('Date', fontsize=14)
plt.ylabel('Cosine similarity', fontsize=14)
# One tick per January, from start_year up to the padding month after end_year.
xtick_labels = [str(year) for year in range(start_year, end_year + 2)]
plt.xticks(all_dates[::12], xtick_labels, rotation=45, fontsize=12)
plt.yticks(fontsize=12)
plt.grid(True, which='both', linestyle='--', linewidth=0.5)
# Half a year of margin on each side of the categorical date axis.
plt.xlim(-6, len(all_dates) - 1 + 6)
plt.tight_layout()
plt.savefig("cosine_similarity_over_time.png", format="png", dpi=300)


# ------------------------
# ------------------------
# FIGURE 3
# ------------------------
# ------------------------

# load topic model data
def load_topic_distributions_with_jfjb_id(filepath):
    """Read a MALLET doc-topics file into ``{JFJB index: [topic proportions]}``.

    Each data line is expected to look like
    ``<mallet doc index> <doc name> <p_0> <p_1> ...``, where the doc name is the
    integer JFJB index written into the MALLET corpus file.

    NOTE(review): this assumes the dense layout (one proportion per topic, in
    topic order). Older MALLET releases wrote sorted (topic, proportion) pairs
    instead — confirm against the MALLET version used.

    Blank lines, lines starting with "#" (a header), and lines without at least
    one proportion are skipped.
    """
    topic_distributions = {}
    with open(filepath, 'r', encoding='utf-8') as file:
        lines = [line.strip() for line in file.readlines()]
    for line in tqdm(lines):
        if not line or line.startswith("#"):
            continue
        parts = line.split()  # whitespace/tab-separated values
        # ID, JFJB_id and at least one topic. The previous "> 3" wrongly required two.
        if len(parts) >= 3:
            jfjb_id = int(parts[1])
            topic_distributions[jfjb_id] = [float(prob) for prob in parts[2:]]  # Skip doc_id and JFJB_id
    return topic_distributions

doc_topics_filepath = 'mallet/doc-topics.txt'  # Update this path
topic_distributions = load_topic_distributions_with_jfjb_id(doc_topics_filepath)

def load_top_terms(topic_keys_filepath, num_terms=3):
    """Read a MALLET topic-keys file into ``{topic id: [top terms]}``.

    Each line is ``<topic id> <Dirichlet parameter> <term1> <term2> ...``.
    Only the first *num_terms* terms are kept, and lines that list no terms
    are skipped.
    """
    top_terms = {}
    # The terms are Chinese, so decode explicitly as UTF-8 (the encoding the
    # corpus file is written in) rather than with the platform default.
    with open(topic_keys_filepath, 'r', encoding='utf-8') as file:
        for line in file:
            parts = line.split()
            if len(parts) > 2:
                top_terms[int(parts[0])] = parts[2:2 + num_terms]
    return top_terms

topic_keys_filepath = 'mallet/topic-keys.txt'
top_terms_per_topic = load_top_terms(topic_keys_filepath)

# Article-level violence score: the mean similarity over all BERT chunks of the
# article. Articles without chunks get no entry.
chunk_scores_by_article = defaultdict(list)
for input_index, similarity in enumerate(tqdm(highest_similarities)):
    chunk_scores_by_article[inputs[input_index]["JFJB_id"]].append(similarity)

# Every collected list is non-empty, so the mean is always defined. (The previous
# non-list fallback branch was unreachable and would have failed on a float.)
violence_distributions = {
    jfjb_id: sum(scores) / len(scores)
    for jfjb_id, scores in tqdm(chunk_scores_by_article.items())
}

# For every topic and (year, month): the topic-weight-weighted mean of the article
# violence scores.
# Sum of topic_weight * violence per topic and month.
topic_month_scores = defaultdict(lambda: defaultdict(float))
# Sum of topic_weight per topic and month, used for normalisation.
topic_month_weights = defaultdict(lambda: defaultdict(float))

for JFJB_id, dist in tqdm(topic_distributions.items()): # for each article in JFJB that the LDA model considers
    # Articles without BERT chunks have no violence score.
    if JFJB_id not in violence_distributions:
        continue

    date = JFJB[JFJB_id]["issue_date"].split("-")
    year_month = (int(date[0]), int(date[1]))
    violence = violence_distributions[JFJB_id]

    for topic_id, topic_weight in enumerate(dist): # for each topic
        topic_month_scores[topic_id][year_month] += topic_weight * violence
        topic_month_weights[topic_id][year_month] += topic_weight

# Normalize the scores. A zero total weight would divide by zero; its weighted sum
# is then 0 as well, so leave the score at 0.
for topic_id, month_scores in topic_month_scores.items():
    for year_month in month_scores:
        weight = topic_month_weights[topic_id][year_month]
        if weight > 0:
            month_scores[year_month] /= weight


# Topic selection for the heatmap. Set topic_filter to one of:
#   VIOLENCE_THRESHOLD - topics whose mean monthly score is >= the threshold,
#   SELECTED_TOPICS    - exactly these topic ids, in this order,
#   None               - every topic of the model.
# The variable is named topic_filter so the built-in filter() is not shadowed.
VIOLENCE_THRESHOLD = 0.3  # include only the topics that were violent enough
SELECTED_TOPICS = sorted({3, 8, 17, 33, 42, 50, 59, 64, 78, 95, 103, 104, 133, 87, 122, 12, 155, 126, 152, 172, 177, 189, 194, 197, 2, 15, 60, 111, 140})
topic_filter = None

# One column per month from January of start_year to December of end_year.
num_months = (end_year - start_year + 1) * 12

if not topic_filter:  # no filter: one row per topic of the model
    num_topics = len(next(iter(topic_distributions.values())))
    filtered_topics = list(range(num_topics))
elif isinstance(topic_filter, (int, float)):
    average_scores = {
        topic_id: sum(month_scores.values()) / len(month_scores)
        for topic_id, month_scores in topic_month_scores.items()
        if month_scores
    }
    filtered_topics = sorted(topic_id for topic_id, avg_score in average_scores.items() if avg_score >= topic_filter)
elif isinstance(topic_filter, (list, tuple)):
    filtered_topics = list(topic_filter)
else:
    raise TypeError(f"topic_filter must be None, a number or a list of topic ids, not {type(topic_filter).__name__}")

# Rows follow filtered_topics; column index = months since January of start_year.
heatmap_data = np.zeros((len(filtered_topics), num_months))
for row, topic_id in enumerate(filtered_topics):
    for (year, month), score in topic_month_scores.get(topic_id, {}).items():
        column = (year - start_year) * 12 + month - 1
        if 0 <= column < num_months:  # a negative index would silently wrap to the end
            heatmap_data[row, column] = score
topic_labels = ["{}: {}".format(tid, ", ".join(top_terms_per_topic.get(tid, []))) for tid in filtered_topics]

# Figure height scales with the number of topic rows.
plt.figure(figsize=(12, 0.3 * len(filtered_topics)))

# Evenly spaced colour stops over [0, 1]: white up to 0.5, shading to red at 0.75
# and to black at 1.0.
colors = ["white", "white", "white", "#FF0000", "black"]
cmap = LinearSegmentedColormap.from_list("custom_diverging", colors, N=256)

# One x tick per January, positioned in month columns counted from start_year.
heatmap_years = range(start_year, end_year + 1)
xticks = [(year - start_year) * 12 for year in heatmap_years]
xticklabels = [str(year) for year in heatmap_years]

ax = sns.heatmap(heatmap_data, cmap=cmap, vmin=0, vmax=1, cbar=False, yticklabels=topic_labels)
ax.set_xticks(xticks)
ax.set_xticklabels(xticklabels, fontproperties=fprop, rotation=45)
# Topic labels contain Chinese top terms. Render them with the same CJK font used
# for the x tick labels; the default font would not display them correctly.
for tick_label in ax.get_yticklabels():
    tick_label.set_fontproperties(fprop)

plt.xlabel("Year", fontproperties=fprop)
plt.ylabel("Topic ID", fontproperties=fprop)

# Manually placed colour bar at the right-hand side, at mid height.
cbar_ax = plt.gcf().add_axes([0.92, 0.33, 0.02, 0.33])  # [left, bottom, width, height]
cbar = plt.colorbar(ax.collections[0], cax=cbar_ax)
cbar.set_label('Violence Score')

plt.savefig("violence_analysis_by_topic_full.png", format="png", dpi=600, bbox_inches="tight")


# ------------------------
# ------------------------
# MALLET TOPIC MODELING
# ------------------------
# ------------------------

# NOTE(review): this reads "stopwords-zh.txt" (hyphen), whereas the top of the
# script reads "stopwords_zh.txt" (underscore). One of the two names is likely
# wrong — confirm.
with open("stopwords-zh.txt", "r", encoding="utf-8") as stopword_file:
    stopwords = [word.strip() for word in stopword_file.readlines()]
stopwords.extend(["图片","照片","其它","其他"])
# Set for O(1) membership in the per-token filter below; a list lookup scanned
# every stopword for every token.
stopword_set = set(stopwords)

min_length = 100
articles_total = 0
total_length_words = 0
total_len_chars = 0
# One MALLET import line per article: "<JFJB index>\t<article_id>\t<tokens>".
# The doc-topics loader reads the JFJB index back from the doc-name column.
with open('mallet/JFJB_corpus.txt', 'w', encoding='utf-8') as output_file:
    for i, article in enumerate(tqdm(JFJB)):
        # Drop stopwords and single-character tokens.
        text = [token for token in article["text_split"].split() if token not in stopword_set and len(token) > 1]
        if len(text) > min_length:
            total_length_words += len(text)
            total_len_chars += sum(len(word) for word in text)
            text = " ".join(text)
            to_write = f"{i}\t{article['article_id']}\t{text}\n"
            output_file.write(to_write)
            articles_total += 1
print(f"Processing complete. Total {articles_total} articles ({total_length_words} words, {total_len_chars} characters) have been written to 'JFJB_corpus.txt'")